package com.omega.example.transformer.tokenizer.bertTokenizer;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.omega.common.utils.JsonUtils;
import com.omega.example.transformer.utils.LagJsonReader;


/**
 * Constructs a BERT tokenizer. Based on WordPiece.
 * 
 * This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which
 * contains most of the methods. Users should refer to the superclass for more
 * information regarding methods.
 * 
 * Args:
 * 
 * vocab_file (:obj:`string`): File containing the vocabulary.
 * 
 * do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to
 * lowercase the input when tokenizing.
 * 
 * do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether
 * to do basic tokenization before WordPiece.
 * 
 * never_split (:obj:`bool`, `optional`, defaults to :obj:`True`): List of
 * tokens which will never be split during tokenization. Only has an effect when
 * :obj:`do_basic_tokenize=True`
 * 
 * unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): The unknown
 * token. A token that is not in the vocabulary cannot be converted to an ID and
 * is set to be this token instead.
 * 
 * sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): The separator
 * token, which is used when building a sequence from multiple sequences, e.g.
 * two sequences for sequence classification or for a text and a question for
 * question answering. It is also used as the last token of a sequence built
 * with special tokens.
 * 
 * pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): The token used
 * for padding, for example when batching sequences of different lengths.
 * 
 * cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): The classifier
 * token which is used when doing sequence classification (classification of the
 * whole sequence instead of per-token classification). It is the first token of
 * the sequence when built with special tokens.
 * 
 * mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): The token used
 * for masking values. This is the token used when training this model with
 * masked language modeling. This is the token which the model will try to
 * predict.
 * 
 * tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
 * Whether to tokenize Chinese characters. This should likely be deactivated for
 * Japanese: see: https://github.com/huggingface/transformers/issues/328
 */

public class BertTokenizer implements Tokenizer {

	private String vocab_file = "vocab.txt";
	private Map<String, Integer> token_id_map;
	private Map<Integer, String> id_token_map;
	private boolean do_lower_case = true;
	private boolean do_basic_tokenize = true;
	private List<String> never_split = new ArrayList<String>();
	public String unk_token = "[UNK]";
	public String sep_token = "[SEP]";
	public String pad_token = "[PAD]";
	public String cls_token = "[CLS]";
	public String mask_token = "[MASK]";
	private boolean tokenize_chinese_chars = true;
	private BasicTokenizer basic_tokenizer;
	private WordpieceTokenizer wordpiece_tokenizer;
	
	public int sos;
	public int eos;
	public int pad;
	public int unk;
	
	private final static String[] _patterns = new String[]{"\\'", "\\\"", "\\.", "<br />", "\\,", "\\(", "\\)", "\\!", "\\?", "\\;", "\\:", "\\s+", "\\r", "\n"};

	private final static String[] _replacements = new String[] {" '  ", "", " . ", " ", " , ", " ( ", " ) ", " ! ", " ? ", " ", " ", " ","",""};

	private static final int MAX_LEN = 512;
	
	public BertTokenizer(String vocab_file, boolean do_lower_case, boolean do_basic_tokenize, List<String> never_split,boolean tokenize_chinese_chars) {
		this.vocab_file = vocab_file;
		this.do_lower_case = do_lower_case;
		this.do_basic_tokenize = do_basic_tokenize;
		this.never_split = never_split;
		this.tokenize_chinese_chars = tokenize_chinese_chars;
		init(vocab_file);
	}
	
	public BertTokenizer(String vocab_file, boolean do_lower_case,boolean tokenize_chinese_chars) {
		this.vocab_file = vocab_file;
		this.do_lower_case = do_lower_case;
		this.tokenize_chinese_chars = tokenize_chinese_chars;
		init(vocab_file);
	}
	
	public BertTokenizer(String vocab_file, boolean do_lower_case, boolean do_basic_tokenize, List<String> never_split,
			String unk_token, String sep_token, String pad_token, String cls_token, String mask_token,
			boolean tokenize_chinese_chars) {
		this.vocab_file = vocab_file;
		this.do_lower_case = do_lower_case;
		this.do_basic_tokenize = do_basic_tokenize;
		this.never_split = never_split;
		this.unk_token = unk_token;
		this.sep_token = sep_token;
		this.pad_token = pad_token;
		this.cls_token = cls_token;
		this.mask_token = mask_token;
		this.tokenize_chinese_chars = tokenize_chinese_chars;
		init(vocab_file);
	}

	public BertTokenizer() {
		init();
	}

	private void init() {
		try {
			this.token_id_map = load_vocab(vocab_file);
			this.eos = this.token_id_map.get(cls_token);
			this.sos = this.token_id_map.get(sep_token);
			this.pad = this.token_id_map.get(pad_token);
			this.unk = this.token_id_map.get(unk_token);
		} catch (IOException e) {
			e.printStackTrace();
		}
		this.id_token_map = new HashMap<Integer, String>();
		for (String key : token_id_map.keySet()) {
			this.id_token_map.put(token_id_map.get(key), key);
		}

		if (do_basic_tokenize) {
			this.basic_tokenizer = new BasicTokenizer(do_lower_case, never_split, tokenize_chinese_chars);
		}
		this.wordpiece_tokenizer = new WordpieceTokenizer(token_id_map, unk_token);
	}
	
	private void init(String vocab_file) {
		try {
			this.token_id_map = load_vocab_from_path(vocab_file);
			this.eos = this.token_id_map.get(cls_token);
			this.sos = this.token_id_map.get(sep_token);
			this.pad = this.token_id_map.get(pad_token);
			this.unk = this.token_id_map.get(unk_token);
		} catch (IOException e) {
			e.printStackTrace();
		}
		this.id_token_map = new HashMap<Integer, String>();
		for (String key : token_id_map.keySet()) {
			this.id_token_map.put(token_id_map.get(key), key);
		}

		if (do_basic_tokenize) {
			this.basic_tokenizer = new BasicTokenizer(do_lower_case, never_split, tokenize_chinese_chars);
		}
		this.wordpiece_tokenizer = new WordpieceTokenizer(token_id_map, unk_token);
	}

	private Map<String, Integer> load_vocab_from_path(String vocab_file_name) throws IOException {
		FileInputStream file = new FileInputStream(vocab_file_name);
		return TokenizerUtils.generateTokenIdMap(file);
	}
	
	private Map<String, Integer> load_vocab(String vocab_file_name) throws IOException {
		ClassLoader classloader = Thread.currentThread().getContextClassLoader();
		InputStream file =classloader.getResourceAsStream(vocab_file_name);
		return TokenizerUtils.generateTokenIdMap(file);
	}

	/**
	 * Tokenizes a piece of text into its word pieces.
	 *
	 * This uses a greedy longest-match-first algorithm to perform tokenization
	 * using the given vocabulary.
	 *
	 * For example: input = "unaffable" output = ["un", "##aff", "##able"]
	 *
	 * Args: text: A single token or whitespace separated tokens. This should have
	 * already been passed through `BasicTokenizer`.
	 *
	 * Returns: A list of wordpiece tokens.
	 * 
	 */
	@Override
	public List<String> tokenize(String text) {
		List<String> split_tokens = new ArrayList<String>();
		if (do_basic_tokenize) {
			for (String token : basic_tokenizer.tokenize(text)) {
				for (String sub_token : wordpiece_tokenizer.tokenize(token)) {
					split_tokens.add(sub_token);
				}
			}
		} else {
			split_tokens = wordpiece_tokenizer.tokenize(text);
		}
		return split_tokens;
	}

	public String convert_tokens_to_string(List<String> tokens) {
		// Converts a sequence of tokens (string) in a single string.
		return tokens.stream().map(s -> s.replace("##", "")).collect(Collectors.joining(""));
	}

	public List<Integer> convert_tokens_to_ids(List<String> tokens) {
		List<Integer> output = new ArrayList<Integer>();
		for (String s : tokens) {
			output.add(token_id_map.get(s));
		}
		return output;
	}
	
	public int[] tokens_to_ids(List<String> tokens) {
		int[] output = new int[tokens.size()];
		for (int i = 0;i<tokens.size();i++) {
			output[i] = token_id_map.get(tokens.get(i));
		}
		return output;
	}

	
	public int vocab_size() {
		return token_id_map.size();
	}
	
	public int[] encode(String text) {
		List<String> tokens = this.tokenize(text);
		return tokens_to_ids(tokens);
	}
	
	public String decode(int[] idx) {
		List<String> tokens = new ArrayList<String>();
		for(int i = 0;i<idx.length;i++) {
			tokens.add(id_token_map.get(idx[i]));
		}
		return convert_tokens_to_string(tokens);
	}
	
	public void encodeDataset(String dataPath,String outputPath) {
    	
    	try {

        	List<Map<String, String>> list = LagJsonReader.readJsonFileSamll(dataPath);
    		
    		String strTmp = "";
    		
    		File file = new File(outputPath);
    		FileWriter writer = new FileWriter(file);
           
    		for(int i = 0;i<list.size();i++) {
    			strTmp = list.get(i).get("completion");	
    			for(int p = 0;p<_patterns.length;p++) {
            		strTmp = strTmp.replaceAll(_patterns[p], _replacements[p]);
    			}
    			if(!strTmp.equals(" ") && !strTmp.equals("")) {
    				String idxStr = "";
    				int[] idx = encode(strTmp);
    				for(int id:idx) {
    					idxStr += id + " ";
    				}
    				writer.write(idxStr + "\n");
            	}
    			System.out.println(i);
    		}
        	
    		 writer.close();

             System.out.println("Data has been written to the file.");
             
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
    	
    }
	
	public void encodeMedicalDataset(String dataPath,String outputPath) {
    	
    	try {
    		
        	List<Map<String, String>> list = LagJsonReader.readRowJsonFile(dataPath);
    		
    		String strTmp = "";
    		
    		File file = new File(outputPath);
    		FileWriter writer = new FileWriter(file);
           
    		for(int i = 0;i<list.size();i++) {
    			strTmp = list.get(i).get("text");	
    			for(int p = 0;p<_patterns.length;p++) {
            		strTmp = strTmp.replaceAll(_patterns[p], _replacements[p]);
    			}
    			if(!strTmp.equals(" ") && !strTmp.equals("")) {
    				String idxStr = "";
    				int[] idx = encode(strTmp);
    				for(int id:idx) {
    					idxStr += id + " ";
    				}
    				writer.write(idxStr + "\n");
            	}
    			System.out.println(i);
    		}
        	
    		 writer.close();

             System.out.println("Data has been written to the file.");
             
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
    	
    }
	
	public void encodeBaiKeDataset(String dataPath,String outputPath) {

		try {
			Map<String,Object> once = new HashMap<String,Object>();
			File file = new File(outputPath);
    		FileWriter writer = new FileWriter(file);
           
		    FileReader fileReader = new FileReader(dataPath);
		    BufferedReader bufferedReader = new BufferedReader(fileReader);
		    String line;
		    String strTmp = "";
		    int i = 0;
		    while ((line = bufferedReader.readLine()) != null) {
		    	once = JsonUtils.gson.fromJson(line, HashMap.class);
		    	List<Map<String,Object>> sections = (List<Map<String, Object>>) once.get("sections");
		    	if(once.get("summary") != null && !once.get("summary").toString().equals("")) {
	    			strTmp = once.get("title").toString() + "： " +  once.get("summary").toString();
	    		}else {
	    			if(sections.size() > 0) {
	    				strTmp = once.get("title").toString();
			    	}
	    		}

		    	for(Map<String,Object> os:sections) {
	    			String content = os.get("content").toString();
	    			strTmp += os.get("title").toString() + "：" + content + "。";
	    		}
		    	
		    	for(int p = 0;p<_patterns.length;p++) {
		    		strTmp = strTmp.replaceAll(_patterns[p], _replacements[p]);
	        	}	
		    	
				if(!strTmp.equals(" ") && !strTmp.equals("")) {
					strTmp.replaceAll(" ", "");
	        	}
				
    			for(int p = 0;p<_patterns.length;p++) {
            		strTmp = strTmp.replaceAll(_patterns[p], _replacements[p]);
    			}
    			if(!strTmp.equals(" ") && !strTmp.equals("")) {
    				String idxStr = "";
    				int[] idx = encode(strTmp);
    				for(int id:idx) {
    					idxStr += id + " ";
    				}
    				writer.write(idxStr + "\n");
            	}
    			System.out.println(i);
    			i++;
		    }
		    bufferedReader.close();
		    writer.close();
		} catch (IOException e) {
		    e.printStackTrace();
		}

        System.out.println("Data has been written to the file.");
         
    }
	
	public void mergeData(String[] paths,String outpath) throws IOException {

		File file = new File(outpath);
		FileWriter writer = new FileWriter(file);
       
		for(String path:paths) {

			try (FileReader fileReader = new FileReader(path);
				 BufferedReader bufferedReader = new BufferedReader(fileReader);){
			    String line;
			    int i = 0;
			    while ((line = bufferedReader.readLine()) != null) {
			    	writer.write(line + "\n");
			    	System.out.println(i);
			    	i++;
			    }
			    bufferedReader.close();
			    
			} catch (Exception e) {
				// TODO: handle exception
			}
			
		}
		
		writer.close();
		
	}
	
	public static void main(String[] args) {
		
		try {
			
			String vocab_file = "H:\\transformer_dataset\\vocab.txt";
			boolean do_lower_case = true;
			boolean tokenize_chinese_chars = true;
			
			BertTokenizer tokenizer = new BertTokenizer(vocab_file, do_lower_case, tokenize_chinese_chars);
			
//			String datasetPath = "H:\\transformer_dataset\\wikipedia-cn-20230720-filtered.json";
//			String outputPath = "H:\\transformer_dataset\\wiki_idx_smallvocab.txt";
			
//			String datasetPath = "H:\\transformer_dataset\\train_encyclopedia.json";
//			String outputPath = "H:\\transformer_dataset\\medical_idx_smallvocab.txt";
			
//			String datasetPath = "H:\\transformer_dataset\\563w_baidubaike.json";
//			String outputPath = "H:\\transformer_dataset\\baike_idx_smallvocab.txt";
			
//			tokenizer.encodeBaiKeDataset(datasetPath, outputPath);		
			
//			String[] paths = new String[] {
//					"H:\\transformer_dataset\\wiki_idx_smallvocab.txt",
//					"H:\\transformer_dataset\\medical_idx_smallvocab.txt",
//					"H:\\transformer_dataset\\baike_idx_smallvocab.txt"
//			};
//			
//			String outpath = "H:\\transformer_dataset\\wbm_idx_smallvocab.txt";
			
			String[] paths = new String[] {
					"H:\\transformer_dataset\\wiki_idx_chatglm_vocab.txt",
					"H:\\transformer_dataset\\medical_idx_chatglm_vocab.txt",
					"H:\\transformer_dataset\\baike_idx_chatglm_vocab.txt"
			};
			
			String outpath = "H:\\transformer_dataset\\wbm_idx_chatglm_vocab.txt";
			
			tokenizer.mergeData(paths, outpath);
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
}
